{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import config\n",
    "from dataloader.loader import Loader\n",
    "from preprocessing.utils import Preprocess, remove_empty_docs\n",
    "from dataloader.embeddings import GloVe\n",
    "from model.cnn_document_model import DocumentModel, TrainingParameters\n",
    "from keras.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import numpy as np\n",
    "from keras.utils import to_categorical\n",
    "import keras.backend as K\n",
    "\n",
    "\n",
    "from sklearn.manifold import TSNE\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the 20 Newsgroups Train and Test Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training split: fetch it, then pass documents and targets through remove_empty_docs\n",
    "# (NOTE(review): presumably drops blank documents together with their labels - confirm in preprocessing.utils).\n",
    "dataset = Loader.load_20newsgroup_data(subset='train')\n",
    "corpus, labels = remove_empty_docs(dataset.data, dataset.target)\n",
    "\n",
    "# Test split: identical treatment.\n",
    "test_dataset = Loader.load_20newsgroup_data(subset='test')\n",
    "test_corpus, test_labels = remove_empty_docs(test_dataset.data, test_dataset.target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mapping the 20 Groups to 6 High-Level Categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# One list of newsgroup names per high-level category; the position of a list\n",
    "# in `category_members` is the category id (0-5) assigned to its members.\n",
    "category_members = [\n",
    "    ['comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware',\n",
    "     'comp.sys.mac.hardware', 'comp.windows.x'],\n",
    "    ['rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey'],\n",
    "    ['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space'],\n",
    "    ['misc.forsale'],\n",
    "    ['talk.politics.misc', 'talk.politics.guns', 'talk.politics.mideast'],\n",
    "    ['talk.religion.misc', 'alt.atheism', 'soc.religion.christian'],\n",
    "]\n",
    "\n",
    "# Flatten into the newsgroup-name -> category-id lookup used by the cells below.\n",
    "six_groups = {\n",
    "    group_name: category_id\n",
    "    for category_id, group_names in enumerate(category_members)\n",
    "    for group_name in group_names\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# map_20_2_6[k] is the high-level category id of original newsgroup index k.\n",
    "map_20_2_6 = [six_groups[dataset.target_names[group_idx]] for group_idx in range(20)]\n",
    "\n",
    "# Replace the 20-way targets with their 6-way category ids.\n",
    "# NOTE: `labels` / `test_labels` are overwritten from themselves, so this cell must run\n",
    "# exactly once per kernel session - re-running it would remap already-mapped ids.\n",
    "labels = [map_20_2_6[fine_label] for fine_label in labels]\n",
    "test_labels = [map_20_2_6[fine_label] for fine_label in test_labels]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-process Text into Word-Index Sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Class-level setting, assigned before the instance is created.\n",
    "# NOTE(review): presumably a minimum word-frequency threshold for the vocabulary - confirm in preprocessing.utils.\n",
    "Preprocess.MIN_WD_COUNT = 5\n",
    "\n",
    "# Fit on the training corpus only; fit() returns the training documents as sequences.\n",
    "preprocessor = Preprocess(corpus=corpus)\n",
    "corpus_to_seq = preprocessor.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Convert the test documents with the preprocessor that was fitted on the training corpus.\n",
    "test_corpus_to_seq = preprocessor.transform(test_corpus)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# NOTE(review): 50 is presumably the GloVe vector dimensionality - confirm in dataloader.embeddings.\n",
    "glove = GloVe(50)\n",
    "\n",
    "# Embedding values for the words in the preprocessor's word index.\n",
    "initial_embeddings = glove.get_embedding(preprocessor.word_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Keyword arguments are grouped by what their names suggest; see\n",
    "# model.cnn_document_model.DocumentModel for their exact meaning.\n",
    "newsgrp_model = DocumentModel(\n",
    "    # vocabulary and pre-trained embeddings\n",
    "    vocab_size=preprocessor.get_vocab_size(),\n",
    "    word_index=preprocessor.word_index,\n",
    "    embedding_weights=initial_embeddings,\n",
    "    train_embedding=True,\n",
    "    # word- and sentence-level convolution settings\n",
    "    num_sentences=Preprocess.NUM_SENTENCES,\n",
    "    word_kernel_size=5,\n",
    "    sent_filters=20,\n",
    "    sent_k_maxpool=5,\n",
    "    conv_activation='relu',\n",
    "    learn_word_conv=True,\n",
    "    learn_sent_conv=True,\n",
    "    # dropout / noise / regularisation\n",
    "    input_dropout=0.2,\n",
    "    sent_dropout=0.4,\n",
    "    hidden_gaussian_noise_sd=0.5,\n",
    "    final_layer_kernel_regularizer=0.1,\n",
    "    # dense layers; 6 final units, one per high-level category\n",
    "    hidden_dims=64,\n",
    "    num_hidden_layers=2,\n",
    "    num_units_final_layer=6,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training settings for the 6-class run; all three artifact paths live under config.MODEL_DIR.\n",
    "train_params = TrainingParameters(\n",
    "    '6_newsgrp_largeclass',\n",
    "    model_file_path=config.MODEL_DIR + '/20newsgroup/model_6_01.hdf5',\n",
    "    model_hyper_parameters=config.MODEL_DIR + '/20newsgroup/model_6_01.json',\n",
    "    model_train_parameters=config.MODEL_DIR + '/20newsgroup/model_6_01_meta.json',\n",
    "    num_epochs=20,\n",
    "    batch_size=128,\n",
    "    validation_split=0.10,\n",
    "    learning_rate=0.01,\n",
    ")\n",
    "\n",
    "# Save the training parameters, then hand the hyper-parameter path to the model's own save hook.\n",
    "train_params.save()\n",
    "newsgrp_model._save_model(train_params.model_hyper_parameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile and run model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Callbacks: write weights to train_params.model_file_path only when the monitored\n",
    "# metric improves, and stop training after 2 epochs without improvement.\n",
    "checkpointer = ModelCheckpoint(\n",
    "    filepath=train_params.model_file_path,\n",
    "    verbose=1,\n",
    "    save_best_only=True,\n",
    "    save_weights_only=True,\n",
    ")\n",
    "early_stop = EarlyStopping(patience=2)\n",
    "\n",
    "# Compile with the optimizer carried by the training parameters.\n",
    "newsgrp_model._model.compile(\n",
    "    loss=\"categorical_crossentropy\",\n",
    "    optimizer=train_params.optimizer,\n",
    "    metrics=[\"accuracy\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Pin the one-hot width to the number of high-level categories. Without\n",
    "# num_classes, to_categorical infers the width from the largest label present,\n",
    "# so y_train and y_test could end up with different widths (or a width that\n",
    "# does not match the model's 6-unit output layer) if a class is absent from a split.\n",
    "num_classes = len(set(six_groups.values()))\n",
    "\n",
    "x_train = np.array(corpus_to_seq)\n",
    "y_train = to_categorical(np.array(labels), num_classes=num_classes)\n",
    "\n",
    "x_test = np.array(test_corpus_to_seq)\n",
    "y_test = to_categorical(np.array(test_labels), num_classes=num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "classifier = newsgrp_model.get_classification_model()\n",
    "\n",
    "# Set LR on the already-compiled optimizer.\n",
    "K.set_value(classifier.optimizer.lr, train_params.learning_rate)\n",
    "\n",
    "classifier.fit(x_train, y_train,\n",
    "               batch_size=train_params.batch_size,\n",
    "               epochs=train_params.num_epochs,\n",
    "               verbose=2,\n",
    "               validation_split=train_params.validation_split,\n",
    "               callbacks=[checkpointer, early_stop])\n",
    "\n",
    "# Early stopping (patience=2) ends training after the best epoch has passed, so the\n",
    "# in-memory weights are those of the last epoch, not the best one. Reload the weights\n",
    "# written by the save_best_only checkpoint so that the evaluation, the predictions and\n",
    "# the document embeddings computed later all use the best model.\n",
    "classifier.load_weights(train_params.model_file_path)\n",
    "\n",
    "# evaluate() returns [loss, accuracy] (the model was compiled with metrics=[\"accuracy\"]);\n",
    "# print it, since a bare expression in the middle of a cell is never displayed.\n",
    "test_loss, test_accuracy = classifier.evaluate(x_test, y_test, verbose=2)\n",
    "print('test loss: %.4f - test accuracy: %.4f' % (test_loss, test_accuracy))\n",
    "\n",
    "# Class probabilities -> predicted category ids.\n",
    "preds = classifier.predict(x_test)\n",
    "preds_test = np.argmax(preds, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Model Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
    "\n",
    "# Printed in this order: per-class report, confusion matrix, overall accuracy.\n",
    "for metric_fn in (classification_report, confusion_matrix, accuracy_score):\n",
    "    print(metric_fn(test_labels, preds_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization: Document Embeddings with t-SNE - What the Model Learned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from utils import scatter_plot\n",
    "\n",
    "# Run the test sequences through the document sub-model and show the shape of its output.\n",
    "document_encoder = newsgrp_model.get_document_model()\n",
    "doc_embeddings = document_encoder.predict(x_test)\n",
    "print(doc_embeddings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Project the document embeddings down to 2 components; random_state is fixed at 42.\n",
    "tsne = TSNE(n_components=2, random_state=42)\n",
    "doc_proj = tsne.fit_transform(doc_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Plot the 2-D projection, passing the 6-way test labels as an array alongside the points.\n",
    "test_label_array = np.array(test_labels)\n",
    "f, ax, sc, txts = scatter_plot(doc_proj, test_label_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Write the scatter-plot figure to a PNG; the path is relative to the current working directory.\n",
    "f.savefig('nws_grp_embd.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
